﻿#include "simplesock.h"

#include <errno.h> //errno
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <string.h>
#include <sys/select.h> //select
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <unistd.h> //close

#include "defcom.h"

// Creates a socket wrapper that owns no descriptor yet (m_fd == -1) and has
// an empty error message.
SockBase::SockBase()
    : m_fd(-1),
      m_err("")
{}

// Destructor. Deliberately does nothing: m_fd is NOT closed here, so the
// descriptor has to be released through closefd() before the object goes away.
// NOTE(review): an object destroyed while still holding a descriptor leaks it.
SockBase::~SockBase() {}

// Waits up to `timeout` milliseconds for `fd` to become readable
// (bRead == true) or writable (bRead == false), using select().
//
// Returns false when the descriptor is ready, and true otherwise: on timeout,
// on select() failure, or when `fd` cannot be monitored at all (negative, or
// too large to be stored in an fd_set).
bool SockBase::TimeoutRWCheck(int fd, int timeout, bool bRead)
{
    // FD_SET with an fd outside [0, FD_SETSIZE) writes past the end of the
    // fd_set (undefined behavior), so such descriptors are reported as "not
    // ready" instead.
    if (fd < 0 || fd >= FD_SETSIZE) {
        return true;
    }

    struct timeval tv;
    tv.tv_sec = timeout / 1000;
    tv.tv_usec = (timeout % 1000) * 1000;

    fd_set fds;
    FD_ZERO(&fds);
    FD_SET(fd, &fds);

    // Watch the single descriptor in either the read set or the write set.
    fd_set *rset = bRead ? &fds : NULL;
    fd_set *wset = bRead ? NULL : &fds;
    int ret = select(fd + 1, rset, wset, NULL, &tv);

    return !(ret > 0 && FD_ISSET(fd, &fds));
}

// Reads `size` bytes from m_fd into `data`, looping over short reads until
// the requested amount has arrived or the peer closes the connection.
//
// timeout: milliseconds to wait for m_fd to become readable before reading;
//          0 skips the wait, a negative value is rejected.
// Returns the number of bytes read (smaller than `size` only when EOF was hit),
//   -1 for a negative timeout,
//   -2 if m_fd did not become readable in time (or could not be waited on),
//   -3 on a read() error.
// On failure the reason is stored in m_err (see GetErrorInfo()).
// NOTE(review): the timeout only guards the first read; subsequent reads block
// (or spin on EAGAIN if m_fd is non-blocking) until more data arrives.
int SockBase::Read(void *data, int size, int timeout)
{
    if (timeout < 0) {
        m_err = "time out value error";
        return -1; // invalid timeout value
    }
    if (timeout > 0 && TimeoutRWCheck(m_fd, timeout, true)) {
        m_err = "TimeoutRWCheck error";
        return -2; // timed out or select() failed
    }

    char *ptr = static_cast<char *>(data);
    int nLeft = size;
    while (nLeft > 0) {
        ssize_t nRecv = read(m_fd, ptr, nLeft);
        if (nRecv < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) {
                continue; // transient condition: try again
            }
            m_err = strerror(errno);
            return -3;
        }
        if (nRecv == 0) {
            break; // peer closed the connection
        }
        nLeft -= static_cast<int>(nRecv);
        ptr += nRecv;
    }
    return size - nLeft;
}

int SockBase::Write(void *data, int size, int timeout)
{
    if (timeout < 0) {
        m_err = "time out value error";
        return -1; // 超时时间设置错误
    } else if (timeout > 0) {
        if (TimeoutRWCheck(m_fd, timeout, false) || errno < 0) {
            return -2; // 超时或发生错误
        }
    }
    char *ptr = (char *)data;
    int nLeft = 0;
    int nSend = 0;
    nLeft = size;
    while (nLeft > 0) {
        if ((nSend = write(m_fd, ptr, nLeft)) <= 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                nSend = 0;
            else {
                m_err = strerror(errno);
                return -3;
            }
        }
        nLeft -= nSend;
        ptr += nSend;
    }
    return (size - nLeft);
}

const char *SockBase::GetErrorInfo()
{
    return m_err.c_str();
}

// Closes m_fd if a descriptor is held and marks it as closed (-1).
// Descriptor 0 is a valid socket fd (e.g. when stdin was closed), so any
// non-negative value counts as open.
// Returns 0 on success or when nothing was open, otherwise close()'s -1; in
// that case strerror(errno) is appended to m_err.
int SockBase::closefd()
{
    int ret = 0;
    if (m_fd >= 0) {
        ret = close(m_fd);
        if (ret != 0) {
            m_err += strerror(errno);
        }
        // The descriptor must not be reused even if close() failed.
        m_fd = -1;
    }
    return ret;
}

///////////////////////////////////////////////////////////
///////////////////////client//////////////////////////////

// Client socket starts unconnected; all state is set up by the SockBase
// default constructor.
ClientSock::ClientSock()
    : SockBase()
{}

// Destructor. Does not close the connection; use CloseConnect() for that.
ClientSock::~ClientSock() {}

// Switches `fd` between non-blocking (bAsync == true) and blocking mode by
// toggling O_NONBLOCK in its file status flags.
// Uses the POSIX O_NONBLOCK flag rather than the legacy O_NDELAY: they are
// identical on Linux, but on some systems O_NDELAY makes read() return 0 when
// no data is available, which Read() would mistake for end-of-stream.
// Returns -1 (with errno set) if fcntl() fails, otherwise F_SETFL's result.
int ClientSock::SetAsync(int fd, bool bAsync)
{
    int flags = fcntl(fd, F_GETFL, 0);
    if (-1 == flags) {
        return -1;
    }
    flags = bAsync ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
    return fcntl(fd, F_SETFL, flags);
}

int ClientSock::CreateConnect(const char *addr, int port, int timeout)
{
    if (timeout < 0) {
        m_err = "time oute value is illegal, must be >=0";
        return -1;
    }
    struct addrinfo hints, *result;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    char number[10];
    sprintf(number, "%d", port);
    int ret = getaddrinfo(addr, number, &hints, &result);
    if (0 != ret) {
        m_err = gai_strerror(ret);
    } else {
        m_fd = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
        if (-1 == m_fd) {
            m_err = strerror(errno);
        } else {
            int alive = 1; //设置保活开启
            setsockopt(m_fd, SOL_SOCKET, SO_KEEPALIVE, &alive, sizeof(alive));
            int idle = 5; // 5秒钟无数据，触发保活机制，发送保活包
            setsockopt(m_fd, SOL_TCP, TCP_KEEPIDLE, &idle, sizeof(idle));
            int intv = 1; //如果没有收到回应，则5秒钟后重发保活包
            setsockopt(m_fd, SOL_TCP, TCP_KEEPINTVL, &intv, sizeof(intv));
            int cnt = 3; //连续3次没收到保活包，视为连接失效
            setsockopt(m_fd, SOL_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt));
            int opt = 1;
            int len = sizeof(opt);
            setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, &opt, len);
            int flag = 1;
            ret = setsockopt(m_fd,          /* socket affected */
                             IPPROTO_TCP,   /* set option at TCP level */
                             TCP_NODELAY,   /* name of option */
                             (char *)&flag, /* the cast is historical cruft */
                             sizeof(int));  /* length of option value */
            ret = SetAsync(m_fd);           //设置非阻塞socket
            if (-1 == ret) {
                m_err = strerror(errno);
                closefd();
            } else {
                if (-1 != connect(m_fd, result->ai_addr, result->ai_addrlen)) { //连接成功
                    ret = SetAsync(m_fd, false); //设置为阻塞
                    if (-1 == ret) {
                        m_err = strerror(errno);
                        closefd();
                    }
                } else {
                    if (errno != EINPROGRESS && errno != EWOULDBLOCK) {
                        m_err = strerror(errno);
                        closefd();
                    } else //处理阻塞，连接已经启动，但是尚未完成
                    {
                        int n, error;
                        socklen_t len;
                        fd_set rset, wset;
                        struct timeval tval;
                        FD_ZERO(&rset);
                        FD_SET(m_fd, &rset);
                        wset = rset;
                        if (timeout == 0) {
                            n = select(m_fd + 1, &rset, &wset, NULL, NULL); //永不超时
                        } else {
                            tval.tv_sec = timeout / 1000;
                            tval.tv_usec = (timeout % 1000) * 1000;
                            n = select(m_fd + 1, &rset, &wset, NULL, &tval); //设置超时时间
                        }
                        if (0 == n) { //超时
                            closefd();
                            errno = ETIMEDOUT;
                            m_err = strerror(errno);
                        }

                        if (FD_ISSET(m_fd, &rset) || FD_ISSET(m_fd, &wset)) {
                            len = sizeof(error);
                            if (getsockopt(m_fd, SOL_SOCKET, SO_ERROR, &error, &len) < 0 ||
                                error) { //处理套接字连接过程中错误
                                m_err = strerror(errno);
                                closefd();
                            } else {
                                if (-1 == SetAsync(m_fd, false)) // 设置为阻塞
                                {
                                    m_err = strerror(errno);
                                    closefd();
                                } else {
                                    if (error) {
                                        errno = error;
                                        m_err = strerror(errno);
                                        closefd();
                                    }
                                }
                            }
                        } else {
                            m_err = strerror(errno);
                            closefd();
                        }
                    }
                }
            }
        }
    }
    freeaddrinfo(result);
    return m_fd;
}

int ClientSock::CloseConnect()
{
    return closefd();
}

///////////////////////////////////////////////////////////
///////////////////////server//////////////////////////////

// Server socket starts without a listening descriptor; base-class defaults
// provide the initial state.
ServerSock::ServerSock() {}

// Destructor. Does not close the listening descriptor; closefd() does that.
ServerSock::~ServerSock() {}

// Creates an IPv4 TCP listening socket bound to the wildcard address on
// `port` and stores it in m_fd (SO_REUSEADDR is enabled, backlog MAXBACKLOG).
// Returns the listening descriptor, or -1 if socket()/bind()/listen() fails
// (m_err describes the failure). If name resolution fails, m_err is set and
// the current m_fd is returned unchanged.
int ServerSock::CreateServer(const int &port)
{
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_flags = AI_PASSIVE; // passive open: wildcard address for bind()
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    char number[16]; // large enough for any int, including "-2147483648"
    snprintf(number, sizeof(number), "%d", port);

    struct addrinfo *result = NULL;
    int ret = getaddrinfo(NULL, number, &hints, &result);
    if (0 != ret) {
        // `result` is not allocated on failure, so there is nothing to free.
        m_err = gai_strerror(ret);
        return m_fd;
    }

    m_fd = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
    if (-1 == m_fd) {
        m_err = strerror(errno);
    } else {
        int opt = 1;
        setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); // allow address reuse

        if (0 != bind(m_fd, result->ai_addr, result->ai_addrlen)) {
            m_err = strerror(errno);
            closefd();
        } else if (0 != listen(m_fd, MAXBACKLOG)) {
            m_err = strerror(errno);
            closefd();
        }
    }
    freeaddrinfo(result);
    return m_fd;
}

// Accepts one pending connection on the listening descriptor m_fd.
//
// timeout: milliseconds to wait for an incoming connection; 0 waits
//          indefinitely (accept() blocks), a negative value is rejected.
// addr:    optional buffer receiving the peer address; it is assumed to be at
//          least sizeof(struct sockaddr) bytes (enough for the AF_INET
//          sockets this class creates). Pass NULL if the address is not needed.
// Returns the new connection descriptor (any value >= 0, including 0), or -1
// on timeout, invalid timeout, or accept() failure (m_err is set).
// The new descriptor's receive/send buffers are set to RECV_BUF_SIZE /
// SEND_BUF_SIZE on a best-effort basis.
int ServerSock::Accept(int timeout, struct sockaddr *addr)
{
    if (timeout < 0) {
        m_err = "timeout value is illegal.";
        return -1;
    }
    if (timeout > 0 && TimeoutRWCheck(m_fd, timeout)) {
        m_err = "accept timeout.";
        return -1;
    }

    struct sockaddr tmpClientAddr;
    struct sockaddr *clientAddr = (NULL == addr) ? &tmpClientAddr : addr;
    socklen_t addrlen = sizeof(*clientAddr);

    int ret = accept(m_fd, clientAddr, &addrlen);
    if (ret < 0) {
        m_err = strerror(errno);
        return ret;
    }

    int rsize = RECV_BUF_SIZE; // receive buffer size
    int wsize = SEND_BUF_SIZE; // send buffer size
    setsockopt(ret, SOL_SOCKET, SO_RCVBUF, &rsize, sizeof(int));
    setsockopt(ret, SOL_SOCKET, SO_SNDBUF, &wsize, sizeof(int));
    return ret;
}
